package statechurn;

import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.util.logging.Level;
import java.util.logging.Logger;
import primitives.graph.*;
import java.util.*;
import translator.DOTWriter;
import java.util.HashSet;

public class KTails {

    private HashMap<String, Node> nodes;
    private int merges = 0;
    private int counter = 0;
    private Stack<Node> toremove;
    private Graph g;
    private HashMap<Transition, Boolean> visitedTs;
    private HashMap<String, HashMap<String, HashMap>> tailCache;
    private HashMap<String, HashMap<String, Boolean>> sameCache;
    private int k = 5;
    private int cacheMisses;
    private int cacheHits;

    public KTails(HashMap<String, Node> nodes, int k) {
        this.nodes = nodes;
        this.k = k;
        toremove = new Stack<Node>();

        g = new Graph();

        visitedTs = new HashMap<Transition, Boolean>();
        tailCache = new HashMap<String, HashMap<String, HashMap>>();
        sameCache = new HashMap<String, HashMap<String, Boolean>>();
    }

    public void deleteFromCaches(Node n) {
        tailCache.remove(n.getLabel());
        sameCache.remove(n.getLabel());



    }

    public boolean equivalent(HashMap<String, HashMap> trace1, HashMap<String, HashMap> trace2) {
        if (trace1.isEmpty()) {
            return false;
        }

        return (trace1.equals(trace2));
    }

    public void clearVisitedTs() {
        visitedTs = new HashMap<Transition, Boolean>();
    }

    public boolean canMerge(Node node1, Node node2) {

        if (node1 == node2) {
            return false;
        }

        String label1 = node1.getLabel();
        String label2 = node2.getLabel();
        if (label1.compareTo(label2) > 0) {
            label1 = node2.getLabel();
            label2 = node1.getLabel();
        }

        HashMap<String, Boolean> cHit = sameCache.get(label1);

        if (cHit == null) {
            cHit = new HashMap<String, Boolean>();
            sameCache.put(label1, cHit);

        }
        Boolean b = cHit.get(label2);
        if (b == null) {
            cacheMisses++;

            HashMap<String, HashMap> traces1 = tailCache.get(node1.getLabel());
            HashMap<String, HashMap> traces2 = tailCache.get(node2.getLabel());

            if (traces1 == null) {
                traces1 = tracesOf(node1,k);
                
                tailCache.put(node1.getLabel(), traces1);

            }
            if (traces2 == null) {
                traces2 = tracesOf(node2, k);
                tailCache.put(node2.getLabel(), traces2);
            }



            boolean res = equivalent(traces1, traces2);
            cHit.put(label2, res);
            return res;
        } else {
            cacheHits++;
            return b;
        }

















    }

    public HashMap<String, HashMap> tracesOf(Node node, int length) {

        Stack<ArrayList<Object>> stack = new Stack<ArrayList<Object>>();

        HashMap<String, HashMap> ret = new HashMap<String, HashMap>();

        stack.push(l(node, length, ret));
        while (!stack.isEmpty()) {
            ArrayList<Object> top = stack.pop();

            Node n = (Node) top.get(0);
            int l = (Integer) top.get(1);
            HashMap<String, HashMap<String, Object>> rtarg = (HashMap<String, HashMap<String, Object>>) top.get(2);



            if (l > 0) {
                Set<Transition> ts = n.getTransitionsAsT();
                for (Transition it : ts) {

                    Node dest = it.getDestinationNode();

                    HashMap<String, Object> target = rtarg.get(it.getLabel());
                    if (target == null) {
                        target = new HashMap<String, Object>();
                        rtarg.put(it.getLabel(), target);
                    }

                    stack.push(l(dest, l - 1, target));




                }

            }
        }
        return ret;

    }

    public Graph getGraph() {
        Graph g = new Graph();
        for (String k : nodes.keySet()) {
            Node n = nodes.get(k);
            g.addNode(n);
        }
        return g;
    }

	//TODO make sure this works.
    public boolean doStep() {
        Node na = null;
        Node nb = null;
        
        HashSet<MergePair> s = new HashSet<MergePair>();
        
        
        boolean brokeForMemory=  false;
        outerloop:
        for (String label1 : nodes.keySet()) {
            for (String label2 : nodes.keySet()) {
                if (!label1.equals(label2)) {
                    Node n1 = nodes.get(label1);
                    Node n2 = nodes.get(label2);
                    if (n1.refCount > 0 && n2.refCount > 0) {
                        if (canMerge(n1, n2)) {
                            s.add(new MergePair(n2,n1));
                            break;
                        }
                        
                        if(!s.isEmpty() && (double)(Runtime.getRuntime().totalMemory()) / Runtime.getRuntime().maxMemory() > 0.95){
                            brokeForMemory = true;
                            break outerloop;
                            
                        }

                    }
                }
            }
        }
        long start = System.currentTimeMillis();
        int oldSize = nodes.size();
        for(MergePair l : s){
            na = l.getLeft();
            nb = l.getRight();
            if (na != null && nb != null && na.refCount > 0 && nb.refCount > 0){
                replace(nb, na);
                pruneNodes();
            }
            
            double time = (System.currentTimeMillis() - start) / 1000;
            if (time > 5) {
                int size = nodes.size();
                int delta = oldSize - size;
                oldSize = size;
                double rate = (delta / (0.000001 + time));
                System.err.println(
                        String.format(
                        "[doStep] %d merges have been made in %.2fs [%.2f/sec], machine has %d nodes.", delta, time, rate, size));

                start = System.currentTimeMillis();
            }
            
        }
        return !s.isEmpty() || brokeForMemory;
    }

    public void replace(Node node, Node with) {
        //replace these with loops

        if (node == with) {
            return;
        }


        for (String l : nodes.keySet()) {
            Node n = nodes.get(l);

            if (n != null) {
                Set<Transition> transitions = n.getTransitionsAsT();
                for (Transition it : transitions) {

                    Node target = it.getDestinationNode();
                    String label = it.getLabel();

                    if (target == node) {
                        n.deleteTransition(it);
                        n.connect(with, label);
                    }
                }
            }

        }
        clearVisitedTs();

        Set<Transition> ts = node.getTransitionsAsT();

        for (Transition it : ts) {
            merge(with, it);
        }


        nodes.remove(with.getLabel());
        tailCache.remove(with.getLabel());


        if (!with.getLabel().equals("Start")) {
            with.setLabel("merge" + counter);
            counter++;
        }
        if (node.getLabel().equals("Start")) {
            with.setLabel("Start");
        }

        nodes.put(with.getLabel(), with);


        toremove.add(node);

        merges++;
    }

    public static ArrayList l(Object a, Object b) {
        ArrayList ret = new ArrayList();
        ret.add(a);
        ret.add(b);
        return ret;
    }

    public static ArrayList l(Object a, Object b, Object c) {
        ArrayList ret = new ArrayList();
        ret.add(a);
        ret.add(b);
        ret.add(c);
        return ret;
    }

    public void merge(Node _to, Transition _from) {
        Stack<ArrayList> stack = new Stack<ArrayList>();

        stack.push(l(_to, _from));
        while (!stack.isEmpty()) {
            List rs = stack.pop();
            Node to = (Node) rs.get(0);
            Transition from = (Transition) rs.get(1);
            if (visitedTs.get(from) != null) {
                continue;
            }

            visitedTs.put(from, true);

            if (!to.hasTransitionWithLabel(from.getLabel())) {
                to.connect(from.getDestinationNode(), from.getLabel());
            } else {
                try {
                    Transition t = to.transitionWithLabel(from.getLabel());
                    if (t != from) {
                        if (t.getDestinationNode() == from.getDestinationNode()) {
                        } else {
                            Set<Transition> ts = from.getDestinationNode().getTransitionsAsT();
                            for (Transition it : ts) {

                                stack.push(l(t.getDestinationNode(), it));
                            }
                        }
                    }
                } catch (TransitionNotFoundException e) {
                }
            }

        }
    }

    public void pruneNodes() {


        while (!toremove.isEmpty()) {

            Node it = toremove.pop();
            if (it == null) {
                continue;
            }

            if (it.refCount <= 0) {
                Set<Transition> ts = it.getTransitionsAsT();
                for (Transition t : ts) {
                    toremove.push(t.getDestinationNode());

                    it.deleteTransition(t);

                }

                nodes.remove(it.getLabel());
                deleteFromCaches(it);
            }
        }


        toremove.clear();

    }

    public Graph doKTails() {
        boolean go = true;
        int iter = 0;
        DOTWriter writer = new DOTWriter();
        int startSize = nodes.size();
        long startTime = System.currentTimeMillis();

        g = getGraph();
        dump(String.format("iter%d.dot", iter), writer.getRepresentation(g));

        long start = System.currentTimeMillis();
        int oldSize = nodes.size();
        while (go) {

            iter++;
            go = doStep();
            pruneNodes();



            double time = (System.currentTimeMillis() - start) / 1000;
            if (time > 5) {
                int size = nodes.size();
                int delta = oldSize - size;
                oldSize = size;
                double rate = (delta / (0.000001 + time));
                System.err.println(
                        String.format(
                        "[doKTails] %d merges have been made in %.2fs [%.2f/sec], machine has %d nodes.", delta, time, rate, size));

                start = System.currentTimeMillis();
            }
        }
        g = getGraph();
        double time = (System.currentTimeMillis() - startTime) / 1000.0;
        long delta = startSize - nodes.size();
        double rate = delta / (0.000001 + time);
        System.err.println(
                String.format(
                "%d merges have been made in %.2fs [%.2f/sec], machine has %d nodes.", delta, time, rate, nodes.size()));
        dump("final.dot", writer.getRepresentation(g));

        return g;

    }

    public void dump(String filename, String content) {
        BufferedWriter out = null;
        try {
            out = new BufferedWriter(new FileWriter(filename));
            out.write(content);
            out.close();
        } catch (IOException ex) {
            //Ignore exception since it's no big deal if we can't dump output.
        } finally {
            try {
                out.close();
            } catch (IOException ex) {
            }
        }
    }
}